# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os
import os.path as osp

from defusedxml import ElementTree as ET

from datumaro.components.annotation import (
    AnnotationType, Cuboid3d, LabelCategories,
)
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.util import cast
from datumaro.util.image import find_images
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file

from .format import KittiRawPath, OcclusionStates, TruncationStates


class KittiRawExtractor(SourceExtractor):
    # http://www.cvlibs.net/datasets/kitti/raw_data.php
    # https://s3.eu-central-1.amazonaws.com/avg-kitti/devkit_raw_data.zip
    # Check cpp header implementation for field meaning

    def __init__(self, path, subset=None):
        assert osp.isfile(path), path
        self._rootdir = osp.dirname(path)

        super().__init__(subset=subset)

        items, categories = self._parse(path)
        self._items = list(self._load_items(items).values())
        self._categories = categories

    @classmethod
    def _parse(cls, path):
        tracks = []
        track = None
        shape = None
        attr = None
        labels = {}
        point_tags = {'tx', 'ty', 'tz', 'rx', 'ry', 'rz'}

        # Can fail with "XML declaration not well-formed" on documents with
        # <?xml ... standalone="true"?>
        #                       ^^^^
        # (like the original Kitti dataset), while
        # <?xml ... standalone="yes"?>
        #                       ^^^
        # works.
        tree = ET.iterparse(path, events=("start", "end"))
        for ev, elem in tree:
            if ev == "start":
                if elem.tag == 'item':
                    if track is None:
                        track = {
                            'shapes': [],
                            'scale': {},
                            'label': None,
                            'attributes': {},
                            'start_frame': None,
                            'length': None,
                        }
                    else:
                        shape = {
                            'points': {},
                            'attributes': {},
                            'occluded': None,
                            'occluded_kf': False,
                            'truncated': None,
                        }

                elif elem.tag == 'attribute':
                    attr = {}

            elif ev == "end":
                if elem.tag == 'item':
                    assert track is not None

                    if shape:
                        track['shapes'].append(shape)
                        shape = None
                    else:
                        assert track['length'] == len(track['shapes'])

                        if track['label']:
                            labels.setdefault(track['label'], set())

                            for a in track['attributes']:
                                labels[track['label']].add(a)

                            for s in track['shapes']:
                                for a in s['attributes']:
                                    labels[track['label']].add(a)

                        tracks.append(track)
                        track = None

                # track tags
                elif track and elem.tag == 'objectType':
                    track['label'] = elem.text
                elif track and elem.tag in {'h', 'w', 'l'}:
                    track['scale'][elem.tag] = float(elem.text)
                elif track and elem.tag == 'first_frame':
                    track['start_frame'] = int(elem.text)
                elif track and elem.tag == 'count' and track:
                    track['length'] = int(elem.text)

                # pose tags
                elif shape and elem.tag in point_tags:
                    shape['points'][elem.tag] = float(elem.text)
                elif shape and elem.tag == 'occlusion':
                    shape['occluded'] = OcclusionStates(int(elem.text))
                elif shape and elem.tag == 'occlusion_kf':
                    shape['occluded_kf'] = elem.text == '1'
                elif shape and elem.tag == 'truncation':
                    shape['truncated'] = TruncationStates(int(elem.text))

                # common tags
                elif attr is not None and elem.tag == 'name':
                    if not elem.text:
                        raise ValueError("Attribute name can't be empty")
                    attr['name'] = elem.text
                elif attr is not None and elem.tag == 'value':
                    attr['value'] = elem.text or ''
                elif attr is not None and elem.tag == 'attribute':
                    if shape:
                        shape['attributes'][attr['name']] = attr['value']
                    else:
                        track['attributes'][attr['name']] = attr['value']
                    attr = None

        if track is not None or shape is not None or attr is not None:
            raise Exception("Failed to parse anotations from '%s'" % path)

        special_attrs = KittiRawPath.SPECIAL_ATTRS
        common_attrs = ['occluded']

        if has_meta_file(path):
            categories =  { AnnotationType.label: LabelCategories.
                from_iterable(parse_meta_file(path).keys()) }
        else:
            label_cat = LabelCategories(attributes=common_attrs)
            for label, attrs in sorted(labels.items(), key=lambda e: e[0]):
                label_cat.add(label, attributes=set(attrs) - special_attrs)

            categories = {AnnotationType.label: label_cat}

        items = {}
        for idx, track in enumerate(tracks):
            track_id = idx + 1
            for i, ann in enumerate(
                    cls._parse_track(track_id, track, categories)):
                frame_desc = items.setdefault(track['start_frame'] + i,
                    {'annotations': []})
                frame_desc['annotations'].append(ann)

        return items, categories

    @classmethod
    def _parse_attr(cls, value):
        if value == 'true':
            return True
        elif value == 'false':
            return False
        elif str(cast(value, int, 0)) == value:
            return int(value)
        elif str(cast(value, float, 0)) == value:
            return float(value)
        else:
            return value

    @classmethod
    def _parse_track(cls, track_id, track, categories):
        common_attrs = { k: cls._parse_attr(v)
            for k, v in track['attributes'].items() }
        scale = [track['scale'][k] for k in ['w', 'h', 'l']]
        label = categories[AnnotationType.label].find(track['label'])[0]

        kf_occluded = False
        for shape in track['shapes']:
            occluded = shape['occluded'] in {
                OcclusionStates.FULLY, OcclusionStates.PARTLY}
            if shape['occluded_kf']:
                kf_occluded = occluded
            elif shape['occluded'] == OcclusionStates.OCCLUSION_UNSET:
                occluded = kf_occluded

            if shape['truncated'] in {TruncationStates.OUT_IMAGE,
                    TruncationStates.BEHIND_IMAGE}:
                # skip these frames
                continue

            local_attrs = { k: cls._parse_attr(v)
                for k, v in shape['attributes'].items() }
            local_attrs['occluded'] = occluded
            local_attrs['track_id'] = track_id
            attrs = dict(common_attrs)
            attrs.update(local_attrs)

            position = [shape['points'][k] for k in ['tx', 'ty', 'tz']]
            rotation = [shape['points'][k] for k in ['rx', 'ry', 'rz']]

            yield Cuboid3d(position, rotation, scale, label=label,
                attributes=attrs)

    @staticmethod
    def _parse_name_mapping(path):
        rootdir = osp.dirname(path)

        name_mapping = {}
        if osp.isfile(path):
            with open(path, encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line or line.startswith('#'):
                        continue

                    idx, path = line.split(maxsplit=1)
                    path = osp.abspath(osp.join(rootdir, path))
                    assert path.startswith(rootdir), path
                    path = osp.relpath(path, rootdir)
                    name_mapping[int(idx)] = path

        return name_mapping

    def _load_items(self, parsed):
        images = {}
        for d in os.listdir(self._rootdir):
            image_dir = osp.join(self._rootdir, d, 'data')
            if not (d.lower().startswith(KittiRawPath.IMG_DIR_PREFIX) and \
                    osp.isdir(image_dir)):
                continue

            for p in find_images(image_dir, recursive=True):
                image_name = osp.splitext(osp.relpath(p, image_dir))[0]
                images.setdefault(image_name, []).append(p)

        name_mapping = self._parse_name_mapping(
            osp.join(self._rootdir, KittiRawPath.NAME_MAPPING_FILE))

        items = {}
        for frame_id, item_desc in parsed.items():
            name = name_mapping.get(frame_id, '%010d' % int(frame_id))
            items[frame_id] = DatasetItem(id=name, subset=self._subset,
                point_cloud=osp.join(self._rootdir,
                    KittiRawPath.PCD_DIR, name + '.pcd'),
                related_images=sorted(images.get(name, [])),
                annotations=item_desc.get('annotations'),
                attributes={'frame': int(frame_id)})

        for frame_id, name in name_mapping.items():
            if frame_id in items:
                continue

            items[frame_id] = DatasetItem(id=name, subset=self._subset,
                point_cloud=osp.join(self._rootdir,
                    KittiRawPath.PCD_DIR, name + '.pcd'),
                related_images=sorted(images.get(name, [])),
                attributes={'frame': int(frame_id)})

        return items


class KittiRawImporter(Importer):
    @classmethod
    def detect(cls, context: FormatDetectionContext) -> None:
        annot_file = context.require_file('*.xml')

        with context.probe_text_file(
            annot_file, "must be a KITTI-like annotation file",
        ) as f:
            parser = ET.iterparse(f, events=('start',))

            _, elem = next(parser)
            if elem.tag != 'boost_serialization':
                raise Exception

            _, elem = next(parser)
            if elem.tag != 'tracklets':
                raise Exception

    @classmethod
    def find_sources(cls, path):
        return cls._find_sources_recursive(path, '.xml', 'kitti_raw')
